package cn.zhxu.toys.concurrent;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Stream;

/**
 * Parallel-machine task scheduler: pulls tasks page by page from a {@link TaskProvider} and runs
 * them on a shared thread pool, keeping at most {@code concurrency} of one scheduling run's tasks
 * in flight at a time.
 *
 * <p>The class also exposes the underlying pool as an {@link ExecutorService} by plain delegation.
 *
 * <p>NOTE(review): {@link #asyncSchedule} runs the scheduling loop itself on the same pool that
 * executes the tasks. If every pool thread is occupied by a blocked scheduling loop, the tasks it
 * submitted can never start; size the pool with that in mind.
 *
 * @author Troy.Zhou
 * @since v0.3.3
 */
public class ParallelTaskScheduler implements ParallelScheduler, ExecutorService {

    private static final Logger log = LoggerFactory.getLogger(ParallelTaskScheduler.class);

    /** Number of tasks requested from the provider per page. */
    private int batchSize = 20;

    /** Pool that runs the tasks (and, for asyncSchedule, the scheduling loop itself). */
    private final ThreadPoolExecutor exePool;

    /** Number of tasks currently running; only maintained while {@link #verbose} is on. */
    private final AtomicInteger totalTasks;

    /** When true, every task logs the running-task counter and the pool's active count at start and end. */
    private boolean verbose;

    /** Creates a scheduler backed by a new pool with 100 core threads. */
    public ParallelTaskScheduler() {
        this(100);
    }

    /**
     * Creates a scheduler backed by a new {@link ScheduledThreadPoolExecutor}.
     *
     * @param corePoolSize number of core threads of the new pool
     */
    public ParallelTaskScheduler(int corePoolSize) {
        this(new ScheduledThreadPoolExecutor(corePoolSize));
    }

    /**
     * Creates a scheduler backed by the given pool.
     *
     * @param exePool pool used to run tasks; must not be null
     * @throws NullPointerException if {@code exePool} is null
     */
    public ParallelTaskScheduler(ThreadPoolExecutor exePool) {
        this.exePool = Objects.requireNonNull(exePool, "exePool");
        this.totalTasks = new AtomicInteger(0);
    }

    /**
     * Schedules all tasks synchronously, without deduplication.
     *
     * @param concurrency maximum number of this run's tasks in flight at once; must be positive
     * @param provider task provider, queried page by page
     * @param executor task executor, invoked once per task on a pool thread
     */
    @Override
    public <T> void schedule(int concurrency, TaskProvider<T> provider, TaskExecutor<T> executor) {
        schedule(concurrency, provider, executor, null);
    }

    /**
     * Schedules all tasks synchronously and blocks until the last submitted tasks have finished.
     *
     * <p>Pages of {@code batchSize} tasks are requested from the provider starting at page 0. Loading
     * stops after the first page that is shorter than the page size (or empty).
     *
     * @param concurrency maximum number of this run's tasks in flight at once; must be positive
     * @param provider task provider, queried page by page; a null page is treated as empty
     * @param executor task executor, invoked once per task on a pool thread
     * @param identify task identifier used to deduplicate tasks returned by the provider; may be null
     * @throws IllegalArgumentException if {@code concurrency} is not positive
     * @throws RuntimeException if waiting is interrupted, or wrapping the failure of one of the
     *     tasks still tracked at the end of the run (earlier failures are logged)
     */
    @Override
    public <T> void schedule(int concurrency, TaskProvider<T> provider, TaskExecutor<T> executor, Identify<T> identify) {
        if (concurrency <= 0) {
            throw new IllegalArgumentException("concurrency must be positive: " + concurrency);
        }
        // Snapshot so a concurrent setBatchSize cannot make paging inconsistent within one run.
        int pageSize = batchSize;
        List<Future<?>> futures = new LinkedList<>();
        long maxId = 0;
        int page = 0;
        List<T> tasks = fetchPage(provider, page, pageSize);
        while (!tasks.isEmpty()) {
            Stream<T> taskStream = tasks.stream();
            // Deduplicate: skip tasks whose id is not above the largest id already submitted.
            // NOTE(review): this presumes the provider pages are ordered by ascending id, and
            // deduplication only engages once a positive max id has been seen — confirm intended.
            if (identify != null && maxId > 0) {
                long floor = maxId;
                taskStream = taskStream.filter(t -> identify.id(t) > floor);
            }
            Queue<T> taskQueue = toQueue(taskStream);
            if (!taskQueue.isEmpty()) {
                if (identify != null) {
                    // Compare ids as longs; subtracting and narrowing to int overflows for distant ids.
                    maxId = taskQueue.stream().mapToLong(t -> identify.id(t)).max().getAsLong();
                }
                submitTasks(concurrency, taskQueue, executor, futures);
            }
            // A short page means the provider has nothing further.
            if (tasks.size() < pageSize) {
                break;
            }
            page++;
            tasks = fetchPage(provider, page, pageSize);
        }
        waitDone(futures);
    }

    /**
     * Schedules all tasks asynchronously, without deduplication.
     *
     * @param concurrency maximum number of this run's tasks in flight at once; must be positive
     * @param provider task provider, queried page by page
     * @param executor task executor, invoked once per task on a pool thread
     * @return future of the whole scheduling run
     */
    @Override
    public <T> Future<?> asyncSchedule(int concurrency, TaskProvider<T> provider, TaskExecutor<T> executor) {
        return asyncSchedule(concurrency, provider, executor, null);
    }

    /**
     * Schedules all tasks asynchronously by running {@link #schedule(int, TaskProvider, TaskExecutor, Identify)}
     * on the scheduler's own pool.
     *
     * @param concurrency maximum number of this run's tasks in flight at once; must be positive
     * @param provider task provider, queried page by page
     * @param executor task executor, invoked once per task on a pool thread
     * @param identify task identifier used to deduplicate tasks returned by the provider; may be null
     * @return future of the whole scheduling run; exceptions thrown by the run surface from its get()
     */
    @Override
    public <T> Future<?> asyncSchedule(int concurrency, TaskProvider<T> provider, TaskExecutor<T> executor, Identify<T> identify) {
        return exePool.submit(() -> schedule(concurrency, provider, executor, identify));
    }

    // ---- ExecutorService: every method delegates directly to exePool ----

    @Override
    public void execute(Runnable command) {
        exePool.execute(command);
    }

    @Override
    public void shutdown() {
        exePool.shutdown();
    }

    @Override
    public List<Runnable> shutdownNow() {
        return exePool.shutdownNow();
    }

    @Override
    public boolean isShutdown() {
        return exePool.isShutdown();
    }

    @Override
    public boolean isTerminated() {
        return exePool.isTerminated();
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return exePool.awaitTermination(timeout, unit);
    }

    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return exePool.submit(task);
    }

    @Override
    public <T> Future<T> submit(Runnable task, T result) {
        return exePool.submit(task, result);
    }

    @Override
    public Future<?> submit(Runnable task) {
        return exePool.submit(task);
    }

    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
        return exePool.invokeAll(tasks);
    }

    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
            throws InterruptedException {
        return exePool.invokeAll(tasks, timeout, unit);
    }

    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
        return exePool.invokeAny(tasks);
    }

    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
            throws InterruptedException, ExecutionException, TimeoutException {
        return exePool.invokeAny(tasks, timeout, unit);
    }

    /** Loads one page from the provider, treating a null result as an empty page. */
    private static <T> List<T> fetchPage(TaskProvider<T> provider, int page, int pageSize) {
        List<T> tasks = provider.getTaskList(page, pageSize);
        return tasks != null ? tasks : Collections.<T>emptyList();
    }

    /** Drains the stream into a mutable FIFO queue (a LinkedList, so null elements are kept). */
    private static <T> Queue<T> toQueue(Stream<T> taskStream) {
        Queue<T> taskQueue = new LinkedList<>();
        taskStream.forEach(taskQueue::add);
        return taskQueue;
    }

    /**
     * Submits every task in {@code tasks} to the pool, keeping at most {@code concurrency} entries in
     * {@code futures}. While the limit is reached, it sleeps 1 ms and then drops finished futures.
     * Failures of dropped futures are logged. Futures still in the list on return are left for the
     * caller to await.
     *
     * @param concurrency maximum size of {@code futures}; expected to be positive
     * @param tasks tasks to submit; drained by this method
     * @param executor executor invoked once per task on a pool thread
     * @param futures in/out list of in-flight futures, shared across pages of one run
     * @throws RuntimeException if interrupted while waiting (the interrupt flag is restored)
     */
    protected <T> void submitTasks(int concurrency, Queue<T> tasks, TaskExecutor<T> executor, List<Future<?>> futures) {
        while (!tasks.isEmpty()) {
            while (futures.size() < concurrency && !tasks.isEmpty()) {
                T task = tasks.poll();
                futures.add(exePool.submit(() -> runTask(executor, task)));
            }
            if (!tasks.isEmpty()) {
                try {
                    Thread.sleep(1);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e.getMessage(), e);
                }
                pruneDone(futures);
            }
        }
    }

    /** Runs one task; when verbose, logs the running-task counter before and after it. */
    private <T> void runTask(TaskExecutor<T> executor, T task) {
        if (!verbose) {
            executor.execute(task);
            return;
        }
        logProgress(totalTasks.incrementAndGet());
        try {
            executor.execute(task);
        } finally {
            // Decrement even when the task throws, so the counter does not drift upward.
            logProgress(totalTasks.decrementAndGet());
        }
    }

    /** Logs the running-task counter together with the pool's (approximate) active thread count. */
    private void logProgress(int runningTasks) {
        log.info("平行机：{}, {}", runningTasks, exePool.getActiveCount());
    }

    /** Removes finished futures from the list, logging any task failure so it is not silently lost. */
    private static void pruneDone(List<Future<?>> futures) {
        Iterator<Future<?>> it = futures.iterator();
        while (it.hasNext()) {
            Future<?> future = it.next();
            if (future.isDone()) {
                it.remove();
                logIfFailed(future);
            }
        }
    }

    /** Logs the failure cause of an already-completed future, if it completed exceptionally. */
    private static void logIfFailed(Future<?> future) {
        if (future.isCancelled()) {
            return;
        }
        try {
            future.get(); // the future is done, so this returns immediately
        } catch (ExecutionException e) {
            log.error("Parallel task failed", e.getCause());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Waits for each future in order.
     *
     * @throws RuntimeException wrapping the first task failure or interruption encountered; the
     *     remaining futures are not awaited in that case
     */
    private void waitDone(List<Future<?>> futures) {
        for (Future<?> future : futures) {
            try {
                future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e.getMessage(), e);
            } catch (ExecutionException e) {
                throw new RuntimeException(e.getMessage(), e);
            }
        }
    }

    /**
     * Sets the number of tasks requested from the provider per page.
     *
     * @param batchSize page size; must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive (it would make paging loop forever)
     */
    public void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /**
     * Enables or disables per-task progress logging.
     *
     * @param verbose true to log the running-task counter around every task
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

}
